"""
__author__ = 'loong'
"""
import json
import os
import shutil
from collections import OrderedDict
from itertools import chain

try:
    from xml.etree import cElementTree as ET
except ImportError:  # xml.etree.cElementTree was removed in Python 3.9
    from xml.etree import ElementTree as ET

from baye import resid
from baye.structures import CITY_MAX, TOOLS_MAX
from baye import structures
from baye import models
from baye.armtype import ArmType
from baye.res import Res
from baye.period import Period
from baye.util import pretty_print_xml, guard_parents_path, dec_raise_xml
from baye.tool import Tool
from baye.skill import Skill
from baye.parse_sav import GameSaveParser
from baye.image import Animation, ImageSet, Image, PImage
from baye import image
from baye import models_v1
from baye import models_v2

# Version stamped into generated dat.xml files (root ``version`` attribute).
__version__ = '2.0.5'
# Oldest dat.xml ``version`` that Lib.load_xml still accepts.
__compact_version__ = '1.9.0'


def num_version(ver):
    """Turn a dotted version string into a list of integers.

    ``'2.0.5'`` becomes ``[2, 0, 5]``. Lists compare element-wise, so the
    results can be compared directly to order versions numerically.

    :raises ValueError: if a component is not an integer.
    """
    return list(map(int, ver.split('.')))


class City:
    """A city on the strategic map.

    Attributes:
        name: display name of the city.
        x, y: cell coordinates on the city map grid.
        roads: raw road-link bytes; each non-zero byte is the 1-based index
            of a linked city, 0 means "no road".
        mapid: battle map id for the city (exported as ``战斗地图``).
    """

    def __init__(self, name, x=0, y=0, mapid=0):
        self.name = name
        self.x = x
        self.y = y
        self.roads = b''
        # Previously this attribute only existed once a caller assigned it,
        # so a city left without one broke serialisation with AttributeError.
        self.mapid = mapid


# Image/animation resources exported to, and rebuilt from, the ``assets``
# directory. Each entry is ``(name, typ)``:
#   * ``name`` is looked up as an attribute of ``resid`` to get the resource id
#     (see Lib.iter_assets);
#   * ``typ`` selects the codec in Lib.decode_asset -- 1 decodes each item as an
#     ``Animation``, 0 as an ``ImageSet``. Resource id 3 is special-cased there:
#     item 0 is an Animation and the remaining items are ImageSets.
# Order matters only for the iteration order of iter_assets.
ASSETS = [
    ("MAIN_SPE", 1),
    ("FIGHT_TILE", 0),
    ("BING_PIC", 0),
    ("MAKER_SPE", 1),
    ("STEP_PIC", 0),
    ("STATE_PIC", 0),
    ("WEATHER_PIC", 0),
    ("NUM_PICID", 0),
    ("SPE_BACKPIC", 0),
    ("QIBING_SPE", 1),
    ("BUBING_SPE", 1),
    ("JIANBING_SPE", 1),
    ("SHUIJUN_SPE", 1),
    ("JIBING_SPE", 1),
    ("XUANBING_SPE", 1),
    ("SHUISHANG_SPE", 1),
    ("MSGBOX_PIC", 0),
    ("STACHG_SPE", 1),
    ("WEATHER_PIC1", 0),
    ("WEATHER_PIC2", 0),
    ("WEATHER_PIC3", 0),
    ("WEATHER_PIC4", 0),
    ("WEATHER_PIC5", 0),
    ("DAYS_PIC", 0),
    ("SPESTA_PIC", 0),
    ("FIRE_SPE", 1),
    ("WATER_SPE", 1),
    ("WOOD_SPE", 1),
    ("BUMP_SPE", 1),
    ("FENG_SPE", 1),
    ("LIUYAN_SPE", 1),
    ("YUAN_SPE", 1),
    ("ZHEN_SPE", 1),
    ("XJING_SPE", 1),
    ("MAIN_PIC", 0),
    ("YEAR_PIC", 0),
    ("SAVE_PIC", 0),
    ("CITY_PIC", 0),
    ("CITYMAP_TILE", 0),
    ("CITY_ICON", 0),
    ("CITY_POS_ICON", 0),
    ("MAPFACE_ICON", 0),
    ("TACTIC_ICON", 0),
    ("FIGHT_NOTE_ICON", 0),
    ("MAIN_ICON1", 1),
    ("MAIN_ICON2", 1),
    ("MAIN_ICON3", 1),
    ("MAIN_ICON4", 1),
    ("YEAR_ICON1", 1),
    ("YEAR_ICON2", 1),
    ("YEAR_ICON3", 1),
    ("YEAR_ICON4", 1),
]


def showEngineCfg(lib):
    """Print the engine configuration fields stored in ``lib``.

    The raw bytes come from item ``0x10 - 1`` of resource 2. They are decoded
    with ``structures.read_struct`` using the layout below. Each layout entry
    is ``(width, field name)``; the width is presumably a byte count (verify
    against ``structures.read_struct``).
    """
    layout = [(1, field) for field in (
        'enableToolAttackRange',
        'fixCityOffset',
        'fixThewOverFlow',
        'fixFoodOverFlow',
        'fixOverFlow16',
        'fixConsumeMoney',
        'fixFightMoveOutRange',
        'enable16bitConsumeMoney',
        'enableScript',
        'fixAlienateComsumeThew',
        'disableSL',
        'aiLevelUpSpeed',
        'disableAgeGrow',
        'enableCustomRatio',
    )]
    layout += [
        (2, 'ratioOfArmsToLevel'),
        (1, 'ratioOfArmsToAge'),
        (1, 'ratioOfArmsToIQ'),
        (1, 'ratioOfArmsToForce'),

        (1, 'ratioOfAttToForce'),
        (1, 'ratioOfAttToIQ'),
        (1, 'ratioOfAttToAge'),

        (1, 'ratioOfDefenceToForce'),
        (1, 'ratioOfDefenceToIQ'),
        (1, 'ratioOfDefenceToAge'),
        (2, 'ratioOfFoodToArmsPerMouth'),
        (2, 'ratioOfFoodToArmsPerDay'),

        (1, 'armsPerDevotion'),
        (1, 'armsPerMoney'),
        (1, 'maxLevel'),
    ]

    raw = lib.resources[2].items[0x10 - 1]
    fields = structures.read_struct(raw, 0, layout)
    for field_name, value in fields.items():
        print('%-25s: %s' % (field_name, value))


class Config:
    """Output settings used when (de)compiling a lib.

    Both values are class-level defaults; assign on an instance to override
    them for a single run without affecting other instances.
    """

    #: scale factor for images written out (Lib.fix_model rejects values <= 0)
    output_scale: int = 2
    #: lib format version to emit (Lib accepts only 0 or 1)
    output_version: int = 1


class Lib:
    """In-memory model of a game ``lib`` resource file.

    A Lib is populated in one of two ways:

    * from the binary format, via :meth:`parse_bin` / :meth:`load_bin`;
    * from a decompiled resource directory, via :meth:`from_res` /
      :meth:`load_xml`.

    It is written back as binary (:meth:`compile` / :meth:`bin`) or as a
    resource directory (:meth:`decompile` / :meth:`xml`).

    Raw resources live in :attr:`resources`. The higher level "patch" model
    (cities, tools, skills, periods, save name) is derived from those
    resources by :meth:`parse_patch` and written back by :meth:`patch`.
    """

    lib_version: int  # set by load_bin / load_xml before fix_model() runs
    input_scale: int  # image scale of the data being read
    config: Config

    class meta:
        # XML tag / attribute names used in the ``patch`` section of dat.xml.
        cities = '城池清单'
        map = '地图'
        tools = '道具清单'
        periods = '时期清单'
        skills = '技能清单'
        city_tag = '城池'
        city_name = '名称'
        save_name = '存档名称'

    def __init__(self, res_path, config: Config = None):
        """Create an empty Lib.

        :param res_path: directory holding (or receiving) the decompiled
            resources.
        :param config: output settings. A default :class:`Config` is used when
            ``None``; ``parse_bin`` passes ``None`` unless told otherwise.
        :raises Exception: if ``config.output_version`` is not 0 or 1.
        """
        if config is None:
            config = Config()
        if config.output_version not in (0, 1):
            raise Exception(f"不支持的输出lib版本 :{config.output_version}")
        self.config = config
        # resource id -> Res; a None value means "omit from bin()/xml()"
        self.resources = OrderedDict()
        self.tools = []
        self.periods = []
        self.cities = []
        self.skills = []
        self.save_name = ''
        self.arm_types = [
            ArmType('骑兵', 0),
            ArmType('步兵', 1),
            ArmType('弓兵', 2),
            ArmType('水兵', 3),
            ArmType('极兵', 4),
            ArmType('玄兵', 5),
        ]
        self.res_path = res_path
        self.name_map = []
        self.city_map_width = 12
        self.city_map_height = 9

    def fix_model(self):
        """Switch the shared model/structure definitions to ``lib_version``.

        NOTE: this mutates the ``models``, ``structures`` and ``image``
        modules globally. Only the version-0 branch overrides the
        ``structures.*_MAX`` limits, so a later version-1 load in the same
        process keeps the 250 limits.

        :raises Exception: if the output scale is not positive, or the lib
            version is unsupported.
        """
        # Validate before touching any module-level state.
        if self.config.output_scale <= 0:
            raise Exception("scale必须大于0")

        if self.lib_version == 0:
            models.__dict__.update(models_v1.__dict__)
            structures.PERSON_MAX = 250
            structures.TOOLS_MAX = 250
            structures.TOOLS_QUEUE_MAX = 250
        elif self.lib_version == 1:
            models.__dict__.update(models_v2.__dict__)
        else:
            raise Exception(f"不支持的lib版本: {self.lib_version}")

        image.input_scale = self.input_scale
        image.output_scale = self.config.output_scale

        print(f'lib_version    :{self.lib_version}')
        print(f'input_scale    :{self.input_scale}')
        print(f'output_version :{self.config.output_version}')
        print(f'output_scale   :{self.config.output_scale}')

    @classmethod
    def parse_bin(cls, dat, path='', config=None):
        """Build a Lib from the binary lib contents ``dat``."""
        lib = cls(path, config)
        lib.load_bin(dat)
        # showEngineCfg(lib)
        return lib

    def load_bin(self, dat):
        """Load raw resources from binary ``dat``, then derive the patch model.

        Slot 17 of the index table carries version information:

        * ``0xffffffff`` means a version-0 lib at scale 1;
        * otherwise bits 8-15 hold ``lib_version - 1`` and the low nibble
          holds the image scale.

        A slot of ``0xffffffff`` for a resource id means that resource is
        absent.
        """
        ver = structures.get_addr(dat, 17)
        if ver == 0xffffffff:
            self.lib_version = 0
            self.input_scale = 1
        else:
            self.lib_version = ((ver & 0xff00) >> 8) + 1
            self.input_scale = ver & 0xf
        self.fix_model()

        for idx, comment in resid.all_ids:
            addr = structures.get_addr(dat, idx)
            if addr == 0xffffffff:
                continue
            r = Res.parse_bin(idx, dat[addr:])
            r.index = idx
            r.comment = comment
            self.resources[idx] = r

        self.parse_patch()
        return self

    def parse_patch(self):
        """Derive cities, tools, skills, periods and the save name from the
        raw resources loaded by :meth:`load_bin`."""
        const_res = self.resources[resid.IFACE_CONID]

        # Map header. A leading 0x08 marker carries a custom city count and
        # map size (the layout written by patch()); otherwise 38 cities are
        # assumed.
        map_info = const_res.items[resid.IFACE_CONST_IDS.DirectP]
        city_count = 38
        if map_info and map_info[0] == 0x08:
            city_count = map_info[1]
            self.city_map_width = map_info[2]
            self.city_map_height = map_info[3]

        # City names: only the first city_count entries are cities.
        for i, data in enumerate(self.resources[resid.CITY_NAME].items):
            if i == city_count:
                break
            self.cities.append(City(structures.decode_gbk(data)))

        # City map: each non-zero cell holds a 1-based city index.
        c_map = const_res.items[resid.IFACE_CONST_IDS.C_MAP]
        for i, cid in enumerate(c_map):
            if cid:
                city = self.cities[cid - 1]
                city.y, city.x = divmod(i, self.city_map_width)

        # Battle map ids, one per city. Tolerate a table that is shorter or
        # longer than the city list instead of raising IndexError.
        mapids = const_res.items[resid.IFACE_CONST_IDS.dCityMapId]
        for i, city in enumerate(self.cities):
            city.mapid = mapids[i] if i < len(mapids) else 0

        # Roads: 16 bytes per city.
        road_data = self.resources[resid.CITY_LINKR].items[0]
        for ci, city in enumerate(self.cities):
            city.roads = road_data[ci * 16:ci * 16 + 16]

        # Tools: fixed-size records, then parallel name/description tables.
        data = self.resources[resid.GOODS_RESID].items[0]
        size = structures.sizeof(models.TOOL)
        # NOTE(review): the first argument here is the byte offset, whereas
        # feed_xml_patch passes the element index to Tool.from_xml -- confirm
        # which one Tool expects.
        tools = [Tool.parse_bin(i, data[i:]) for i in range(0, len(data), size)]
        for tool, name in zip(tools, self.resources[resid.GOODS_NAME].items):
            tool.name = structures.decode_gbk(name)
        for tool, desc in zip(tools, self.resources[resid.GOODS_INF].items):
            tool.description = structures.decode_gbk(desc)
        # Assigned unconditionally: previously this only happened inside the
        # name loop, so an empty name table left self.tools empty.
        self.tools = tools

        # Skills: the name table ends at the first empty name.
        for i, data in enumerate(self.resources[resid.SKL_NAMID].items):
            name = structures.decode_gbk(data)
            if not name:
                break
            skill = Skill(i)
            skill.name = name
            self.skills.append(skill)

        # Periods
        for i in range(4):
            period = Period(self, i)
            period.load_lib()
            self.periods.append(period)

        data = self.resources[resid.IFACE_STRID].items[resid.IFACE_ITEM_IDS.dSaveFNam]
        self.save_name = structures.decode_gbk(data)

    @classmethod
    def from_res(cls, path, config):
        """Build a Lib from the resource directory ``path`` (reads dat.xml)."""
        xmlfile = os.path.join(path, 'dat.xml')
        with open(xmlfile, 'rb') as f:
            xml = f.read()
        return cls.from_xml(ET.fromstring(xml), path, config=config)

    @classmethod
    def from_xml(cls, xml, path, config):
        """Build a Lib from an already parsed dat.xml root element."""
        lib = cls(path, config)
        lib.load_xml(xml)
        return lib

    @dec_raise_xml
    def load_xml(self, xml):
        """Load a dat.xml root element: raw ``lib`` resources plus ``patch``.

        :raises Exception: if the file version is missing, malformed or older
            than ``__compact_version__``, or a top-level node is missing.
        """
        v = xml.attrib.get('version')
        # Explicit checks instead of ``assert`` so the version gate survives
        # ``python -O``; a missing or malformed version counts as too old.
        try:
            version = num_version(v)
        except (AttributeError, ValueError):
            raise Exception('res 版本低')
        if version < num_version(__compact_version__):
            raise Exception('res 版本低')

        default_scale = 2 if version[0] == 1 else 4
        self.input_scale = int(xml.attrib.get('scale', default_scale))
        self.lib_version = self.config.output_version
        self.fix_model()

        lib_node = None
        patch_node = None
        for e in xml:
            if e.tag == 'lib':
                lib_node = e
            elif e.tag == 'patch':
                patch_node = e
        if lib_node is None or patch_node is None:
            raise Exception('dat.xml缺少lib或patch节点')

        self.feed_xml_lib(lib_node)
        self.feed_xml_patch(patch_node)
        return self

    def feed_xml_lib(self, xml):
        """Load every raw resource under the ``lib`` node."""
        for res_node in xml:
            res = Res.from_xml(res_node)
            self.resources[res.id] = res

    def feed_xml_patch(self, xml):
        """Load cities, map size, tools, skills, periods and the save name
        from the ``patch`` node.

        :raises Exception: on missing sections, duplicate names, too many
            cities/tools/roads, or a period count other than 4.
        """
        map_node = tools_node = skills_node = periods_node = None
        cities_node = None

        for node in xml:
            if node.tag == self.meta.map:
                map_node = node
            elif node.tag == self.meta.tools:
                tools_node = node
            elif node.tag == self.meta.periods:
                periods_node = node
            elif node.tag == self.meta.skills:
                skills_node = node

        if map_node is not None:
            for node in map_node:
                if node.tag == self.meta.cities:
                    cities_node = node

        # Compare with ``is None``: an Element without children is falsy.
        for tag, node in ((self.meta.map, map_node),
                          (self.meta.cities, cities_node),
                          (self.meta.tools, tools_node),
                          (self.meta.skills, skills_node),
                          (self.meta.periods, periods_node)):
            if node is None:
                raise Exception('缺少节点: {}'.format(tag))

        # Cities
        city_names = set()
        self.cities = []
        for n in cities_node:
            if n.tag != self.meta.city_tag:
                continue
            name = n.attrib[self.meta.city_name]
            if name in city_names:
                raise Exception('城池名称重复: {}'.format(name))
            city_names.add(name)
            city = City(name, int(n.attrib['x']), int(n.attrib['y']))
            city.roads = n.attrib['路径']
            city.mapid = int(n.attrib['战斗地图'])
            self.cities.append(city)

        # Read the live value: fix_model() may have overridden the limits.
        if len(self.cities) > structures.CITY_MAX:
            raise Exception('城池数量超限')

        def road_target(name):
            # '-' (or an empty slot) means no road; otherwise a 1-based index.
            name = name.strip()
            return 0 if name == '-' else self.get_city_index_by_name(name) + 1

        for city in self.cities:
            links = bytes(road_target(n) for n in city.roads.split(','))
            if len(links) > 8:
                raise Exception('城池路径超过8条: {}'.format(city.name))
            # Each city occupies exactly 16 bytes of the road table (see
            # parse_patch); pad short lists so later cities stay aligned.
            city.roads = links.ljust(16, b'\x00')

        self.city_map_width = int(map_node.attrib['宽'])
        self.city_map_height = int(map_node.attrib['高'])

        # Tools
        tools = []
        tool_names = set()
        for i, tool_node in enumerate(tools_node):
            if tool_node.tag != Tool.meta.tag:
                continue
            tool = Tool.from_xml(i, tool_node)
            if tool.name in tool_names:
                raise Exception('道具名称重复: {}'.format(tool.name))
            tools.append(tool)
            tool_names.add(tool.name)

        # Read the live value: fix_model() sets structures.TOOLS_MAX for v0.
        tools_max = structures.TOOLS_MAX
        if len(tools) > tools_max:
            raise Exception('道具型号太多, 最多支持{}个'.format(tools_max))

        attack_range_path = self.tools_attack_range_path()
        for tool in tools:
            tool.load_att_range(attack_range_path)
        self.tools = tools

        # Skills
        self.skills = [
            Skill(i).load_xml(node)
            for i, node in enumerate(skills_node)
            if node.tag == Skill.meta.tag
        ]

        # Periods: exactly four are required.
        periods = []
        for i, pnode in enumerate(periods_node):
            if i >= 4:
                raise Exception('不能超过4个时期')
            periods.append(Period.from_xml(pnode, self, i))
        if len(periods) != 4:
            raise Exception('必须有4个时期')
        self.periods = periods

        self.save_name = xml.attrib.get(self.meta.save_name, '')

    def patch(self):
        """Write the patch model (map, cities, roads, tools, periods, save
        name) back into the raw resources.

        :raises Exception: if a city lies outside the map or the save name is
            too long.
        """
        const_res = self.resources[resid.IFACE_CONID]

        # Map header: 0x08 marker, city count, map width, map height.
        const_res.items[resid.IFACE_CONST_IDS.DirectP] = (
            b'\x08'
            + structures.write_int(len(self.cities), 1)
            + structures.write_int(self.city_map_width, 1)
            + structures.write_int(self.city_map_height, 1)
        )

        # City names
        res = Res(resid.CITY_NAME)
        res.items = [structures.encode_gbk(c.name) for c in self.cities]
        self.resources[res.id] = res

        # City map: row-major grid of 1-based city indexes, 0 = empty.
        width, height = self.city_map_width, self.city_map_height
        city_map = [[0] * width for _ in range(height)]
        for i, city in enumerate(self.cities):
            # Reject out-of-range coordinates; negative ones would otherwise
            # silently wrap to the far edge of the grid.
            if not (0 <= city.x < width and 0 <= city.y < height):
                raise Exception('城池坐标超出地图范围: {}'.format(city.name))
            city_map[city.y][city.x] = i + 1
        const_res.items[resid.IFACE_CONST_IDS.C_MAP] = b''.join(bytes(row) for row in city_map)
        const_res.items[resid.IFACE_CONST_IDS.dCityPos] = bytes(
            chain.from_iterable((c.x, c.y) for c in self.cities))
        const_res.items[resid.IFACE_CONST_IDS.dCityMapId] = bytes(c.mapid for c in self.cities)

        # Roads
        self.resources[resid.CITY_LINKR].items = [b''.join(c.roads for c in self.cities)]

        # Tools
        res = Res(resid.GOODS_RESID)
        res.items = [b''.join(tool.bin() for tool in self.tools)]
        self.resources[resid.GOODS_RESID] = res

        res = Res(resid.GOODS_NAME)
        res.items = [structures.encode_gbk(t.name) for t in self.tools]
        self.resources[resid.GOODS_NAME] = res

        res = Res(resid.GOODS_INF)
        res.items = [structures.encode_gbk(t.description) for t in self.tools]
        self.resources[resid.GOODS_INF] = res

        for period in self.periods:
            period.patch()

        # engine config, enable city offset fix
        # (re-fetched: period.patch() may have touched self.resources)
        res = self.resources[resid.IFACE_CONID]
        if len(res.items) == 15:
            res.items.append(b'\x01')

        if self.save_name:
            # GBK to mirror parse_patch, which reads this item with
            # structures.decode_gbk; identical to the old UTF-8 for ASCII.
            bname = self.save_name.encode('gbk')
            if len(bname) > 19:
                raise Exception('存档名称太长')
            self.resources[resid.IFACE_STRID].items[resid.IFACE_ITEM_IDS.dSaveFNam] = bname

    def check(self):
        """Run consistency checks on every period."""
        for period in self.periods:
            period.check()

    def bin(self):
        """Serialise to the binary lib format.

        The output is a table of 116 four-byte offsets followed by every
        non-None resource. A slot value of 0xffffffff marks an absent
        resource.

        :raises Exception: if the output scale does not fit the 4-bit field
            that load_bin reads back.
        """
        indexes = [0xffffffff] * 116
        offset = len(indexes) * 4
        if self.lib_version != 0:
            scale = self.config.output_scale
            # load_bin reads the scale back as ``ver & 0xf``.
            if not 0 < scale <= 0xf:
                raise Exception('scale必须在1到15之间')
            indexes[16] = ((self.lib_version - 1) << 8) + scale

        chunks = []
        for res in self.resources.values():
            if res is None:
                continue
            blob = res.bin()
            indexes[res.id - 1] = offset
            offset += len(blob)
            chunks.append(blob)

        header = b''.join(structures.write_int(addr, 4) for addr in indexes)
        return header + b''.join(chunks)

    def xml(self):
        """Build the dat.xml element tree: ``patch`` first, then raw ``lib``."""
        root = ET.Element('root')
        root.attrib['version'] = __version__
        root.attrib['scale'] = str(self.config.output_scale)

        lib = ET.Element('lib')
        for res in self.resources.values():
            # patch_script() may leave a None placeholder; bin() skips it too.
            if res is not None:
                lib.append(res.xml())
        lib.attrib['doc'] = 'lib原始数据'

        # patch
        patch = ET.Element('patch')

        # map
        map_node = ET.Element(self.meta.map)
        patch.append(map_node)
        map_node.attrib['宽'] = str(self.city_map_width)
        map_node.attrib['高'] = str(self.city_map_height)

        # cities
        cities_node = ET.Element(self.meta.cities)
        for city in self.cities:
            city_node = ET.Element(self.meta.city_tag)
            city_node.attrib[self.meta.city_name] = city.name
            city_node.attrib['x'] = str(city.x)
            city_node.attrib['y'] = str(city.y)
            # Export the first 8 road bytes: 0 -> '-', otherwise the name of
            # the 1-based linked city.
            city_node.attrib['路径'] = ','.join(
                self.cities[cid - 1].name if cid else '-'
                for cid in city.roads[:8]
            )
            city_node.attrib['战斗地图'] = str(city.mapid)
            cities_node.append(city_node)
        map_node.append(cities_node)

        # tools
        tools_node = ET.Element(self.meta.tools)
        for tool in self.tools:
            tools_node.append(tool.xml())
        patch.append(tools_node)

        # skills
        skills_node = ET.Element(self.meta.skills)
        skills_node.attrib['备注'] = '技能暂不支持修改, 敬请期待后续版本'
        for skill in self.skills:
            skills_node.append(skill.xml())
        patch.append(skills_node)

        # periods
        periods_node = ET.Element(self.meta.periods)
        for period in self.periods:
            periods_node.append(period.xml())
        patch.append(periods_node)

        # ElementTree cannot serialise a None attribute value.
        patch.attrib[self.meta.save_name] = self.save_name or ''

        root.append(patch)
        root.append(lib)
        return root

    def load_name_map(self):
        """Load name_map.json if present (written by dump_name_map)."""
        try:
            with open(self.name_map_path(), encoding='utf-8') as f:
                self.name_map = json.load(f)
        except IOError:
            print('name map file is missing')

    def dump_name_map(self):
        """Write one name map per period to name_map.json."""
        name_map = [period.create_name_map() for period in self.periods]
        path = self.name_map_path()
        guard_parents_path(path)
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(name_map, f)

    def name_map_path(self):
        return os.path.join(self.resource_path(), 'name_map.json')

    def resource_path(self):
        return os.path.abspath(self.res_path)

    def avatar_path(self, period):
        parent_path = self.resource_path()
        return os.path.join(parent_path, '将领头像', '{}'.format(period.index + 1))

    def tools_attack_range_path(self):
        return os.path.join(self.resource_path(), '道具攻击范围')

    def assets_path(self):
        return os.path.join(self.resource_path(), 'assets')

    def save_avatars(self):
        for period in self.periods:
            period.save_avatars(self.avatar_path(period))

    def patch_avatars(self):
        for period in self.periods:
            period.patch_avatars(self.avatar_path(period))

    def xmlText(self):
        """Return the pretty-printed dat.xml document."""
        return pretty_print_xml(self.xml())

    def get_tool_by_name(self, name):
        """Return the tool called ``name``; raise if there is none."""
        for t in self.tools:
            if name == t.name:
                return t
        raise Exception('不存在道具 {}'.format(name))

    def get_skill_index_by_name(self, name):
        """Return the skill index for ``name``; -1 for an empty name."""
        if not name:
            return -1
        for t in self.skills:
            if name == t.name:
                return t.index
        raise Exception('不存在技能 {}'.format(name))

    def get_city_index_by_name(self, name):
        """Return the 0-based city index for ``name``; -1 for an empty name."""
        if not name:
            return -1
        for i, ct in enumerate(self.cities):
            if name == ct.name:
                return i
        raise Exception('不存在城池 {}'.format(name))

    def get_arm_type_by_name(self, name):
        """Return the ArmType called ``name``; raise if there is none."""
        for t in self.arm_types:
            if name == t.name:
                return t
        raise Exception('不存在兵种 {}'.format(name))

    def get_old_person_index(self, period, name):
        """Look up ``name`` in the loaded name map for ``period``.

        Returns None when the period or name is unknown.
        """
        try:
            return self.name_map[period][name]
        except (LookupError, TypeError):
            return None

    def script_path(self):
        return os.path.join(self.resource_path(), 'script.js')

    def save_script(self):
        """Export the engine script (trailing NULs stripped), if any."""
        res = self.resources.get(resid.ENGINE_SCRIPT)
        scr = (res and res.raw or b'').rstrip(b'\x00')
        if scr:
            with open(self.script_path(), 'wb') as f:
                f.write(scr)

    def patch_script(self):
        """Load script.js into the ENGINE_SCRIPT resource (NUL-terminated).

        A missing script file marks the resource as None, so bin() omits it.
        """
        if not os.path.exists(self.script_path()):
            self.resources[resid.ENGINE_SCRIPT] = None
            return

        with open(self.script_path(), 'rb') as f:
            script = f.read()
        if script:
            res = Res(resid.ENGINE_SCRIPT)
            res.raw = script + b'\x00'
            self.resources[resid.ENGINE_SCRIPT] = res

    def iter_assets(self):
        """Yield ``(resource id, asset directory, typ)`` for each ASSETS entry."""
        for res, typ in ASSETS:
            rid = getattr(resid, res)
            path = os.path.join(self.assets_path(), '%03d_%s.assets' % (rid, res))
            yield rid, path, typ

    def decode_asset(self, rid, respath, typ):
        """Export every item of resource ``rid`` to ``respath/<i>``.

        ``typ`` selects the codec: 1 -> Animation, 0 -> ImageSet. A resource
        missing from this lib is skipped; load_bin drops absent resources.
        """
        res = self.resources.get(rid)
        if res is None:
            return
        for i, data in enumerate(res.items):
            path = os.path.join(respath, str(i))
            if rid == 3:
                # i > 0 for custom fonts
                typ = 1 if i == 0 else 0
            if typ == 1:
                Animation(data).save(path)
            elif typ == 0:
                ImageSet(data).save(path)

    def decode_assets(self):
        for rid, path, typ in self.iter_assets():
            self.decode_asset(rid, path, typ)

    def encode_asset(self, rid, respath):
        """Rebuild resource ``rid`` from ``respath/0``, ``respath/1``, ...

        Reading stops at the first missing sub-directory. Each
        sub-directory's ``meta.xml`` root tag picks the codec: ``spe`` ->
        Animation, ``image`` -> ImageSet.

        :raises Exception: if an item cannot be encoded.
        """
        items = []
        i = 0
        while True:
            path = os.path.join(respath, str(i))
            if not os.path.isdir(path):
                break
            # Read bytes so the XML declaration, not the locale, decides the
            # encoding.
            with open(os.path.join(path, 'meta.xml'), 'rb') as f:
                meta = ET.fromstring(f.read())
            if meta.tag == 'spe':
                data = Animation().load_assets(path).bin()
            elif meta.tag == 'image':
                data = ImageSet().load_assets(path).bin()
            else:
                data = None
            # Explicit raise (not assert) so the check survives ``python -O``.
            if data is None:
                raise Exception('encode {} failed'.format(path))
            items.append(data)
            i += 1

        res = self.resources.get(rid)
        if res is None:
            if not items:
                return  # absent from both the lib and the assets directory
            res = Res(rid)
            self.resources[rid] = res
        res.items = items

    def encode_assets(self):
        for rid, path, _ in self.iter_assets():
            self.encode_asset(rid, path)

    def save_city_map(self):
        """Render the city-map tiles to city-map.bmp (one tile per map cell)."""
        w = self.city_map_width
        h = self.city_map_height
        sz = 16 * self.config.output_scale
        pImage = PImage.new('1', (w * sz, h * sz), 1)
        imageSet = ImageSet(self.resources[resid.CITYMAP_TILE].items[0])
        for row in range(h):
            for col in range(w):
                tile = imageSet.image(w * row + col).to_image()
                pImage.paste(tile, (col * sz, row * sz))
        pImage.save(os.path.join(self.resource_path(), 'city-map.bmp'))

    def patch_city_map(self):
        """Slice city-map.bmp back into per-cell tiles for CITYMAP_TILE."""
        sz = 16 * self.config.output_scale
        pImage = PImage.open(os.path.join(self.resource_path(), 'city-map.bmp'))
        w = self.city_map_width
        h = self.city_map_height
        pImage = pImage.resize((w * sz, h * sz))
        images = []
        for row in range(h):
            for col in range(w):
                tile = pImage.crop((col * sz, row * sz, (col + 1) * sz, (row + 1) * sz))
                images.append(Image.from_image(tile, sz, sz))
        imageSet = ImageSet().load_images(sz, sz, images)
        self.resources[resid.CITYMAP_TILE].items[0] = imageSet.bin()

    def decompile(self):
        """Export everything into ``res_path``.

        Any existing ``res_path`` directory is deleted first.
        """
        xml = self.xmlText()
        if os.path.exists(self.res_path):
            shutil.rmtree(self.res_path)
        os.makedirs(self.res_path, exist_ok=True)
        xmlfile = os.path.join(self.res_path, 'dat.xml')
        with open(xmlfile, 'wb') as f:
            f.write(xml)
        self.dump_name_map()
        for tool in self.tools:
            tool.dump_att_range(self.tools_attack_range_path())
        self.save_avatars()
        self.save_script()
        self.decode_assets()
        self.save_city_map()

    def compile(self, bin_path):
        """Apply all resource-directory patches and write the binary lib."""
        self.load_name_map()
        self.patch()
        self.check()
        self.patch_avatars()
        self.patch_script()
        self.encode_assets()
        self.patch_city_map()
        # Serialise first so a failure doesn't leave a truncated output file.
        data = self.bin()
        with open(bin_path, 'wb') as f:
            f.write(data)

    def parse_sav(self, sav0, sav1):
        return GameSaveParser(sav0, sav1, ishex=True).parse(self)
